1use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{Arc, Mutex};
14
15use tokio::sync::mpsc;
16use uuid::Uuid;
17
18use super::super::rpc_types::{RpcContext, RpcMessage, RpcOpts, COMMAND_EVENT_RECV};
19
/// Timeout, in milliseconds, substituted whenever an incoming request or an `RpcOpts`
/// value carries a non-positive timeout.
pub const DEFAULT_TIMEOUT_MS: i64 = 5000;
/// Capacity of the bounded per-request response channel created by `send_request`.
const RESP_CH_SIZE: usize = 32;
24
/// Outcome of a handler invocation: `Ok` with an optional JSON payload that becomes the
/// response data, or `Err` with a message that is sent back as the response error.
pub type HandlerResult = Result<Option<serde_json::Value>, String>;
29
/// Boxed handler for a single-response command.
///
/// It receives the request payload and the engine's RPC context and returns a boxed,
/// sendable future resolving to a [`HandlerResult`].
pub type CommandHandler = Box<
    dyn Fn(serde_json::Value, RpcContext) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>>
        + Send
        + Sync,
>;
40
/// Boxed handler for a streaming command.
///
/// It receives the request payload and the engine's RPC context and returns a boxed,
/// sendable future. That future resolves either to a receiver yielding one
/// [`HandlerResult`] per streamed item, or to an error message if the stream could not
/// be started.
pub type StreamHandler = Box<
    dyn Fn(
            serde_json::Value,
            RpcContext,
        )
            -> Pin<Box<dyn Future<Output = Result<mpsc::Receiver<HandlerResult>, String>> + Send>>
        + Send
        + Sync,
>;
51
/// A registered command handler, stored in the engine keyed by command name.
enum Handler {
    /// Produces exactly one result per request.
    Call(CommandHandler),
    /// Produces a channel of results that are forwarded as a sequence of responses.
    #[allow(dead_code)]
    Stream(StreamHandler),
}
57
/// Per-request state used to send responses for one incoming request.
pub struct RpcResponseHandler {
    /// Engine whose output channel the responses are written to.
    engine: Arc<WshRpcEngine>,
    /// Request id copied from the incoming message; used as `resid` on every response.
    req_id: String,
    /// `source` field copied from the incoming message.
    #[allow(dead_code)]
    source: String,
    /// Set once a cancel message for this request id has been received.
    canceled: AtomicBool,
    /// Set once a terminal response or an error has been sent.
    done: AtomicBool,
}
70
71impl RpcResponseHandler {
72 pub fn send_response(&self, data: Option<serde_json::Value>, done: bool) {
75 if self.done.load(Ordering::Relaxed) {
76 return;
77 }
78 let msg = RpcMessage {
79 resid: self.req_id.clone(),
80 data,
81 cont: !done,
82 ..Default::default()
83 };
84 if done {
85 self.done.store(true, Ordering::Relaxed);
86 }
87 self.engine.send_output(msg);
88 }
89
90 pub fn send_error(&self, err: &str) {
92 if self.done.load(Ordering::Relaxed) {
93 return;
94 }
95 self.done.store(true, Ordering::Relaxed);
96 let msg = RpcMessage {
97 resid: self.req_id.clone(),
98 error: err.to_string(),
99 ..Default::default()
100 };
101 self.engine.send_output(msg);
102 }
103
104 #[allow(dead_code)]
106 pub fn is_canceled(&self) -> bool {
107 self.canceled.load(Ordering::Relaxed)
108 }
109
110 #[allow(dead_code)]
112 pub fn get_source(&self) -> &str {
113 &self.source
114 }
115
116 fn cancel(&self) {
118 self.canceled.store(true, Ordering::Relaxed);
119 }
120
121 fn finalize(&self) {
123 if self.done.load(Ordering::Relaxed) {
124 return;
125 }
126 self.send_response(None, true);
127 }
128}
129
/// Receiving side of an outgoing request, returned by `WshRpcEngine::send_request`.
#[allow(dead_code)]
pub struct RpcRequestHandler {
    /// Id generated for the outgoing request.
    req_id: String,
    /// Channel on which the engine delivers responses matching `req_id`.
    resp_rx: mpsc::Receiver<RpcMessage>,
    /// `cont` flag of the most recently received response; starts out `true`.
    last_was_cont: bool,
}
140
141impl RpcRequestHandler {
142 #[allow(dead_code)]
144 pub async fn next_response(&mut self) -> Option<Result<serde_json::Value, String>> {
145 if !self.last_was_cont && self.req_id.is_empty() {
146 return None;
147 }
148 match self.resp_rx.recv().await {
149 Some(msg) => {
150 self.last_was_cont = msg.cont;
151 if !msg.error.is_empty() {
152 Some(Err(msg.error))
153 } else {
154 Some(Ok(msg.data.unwrap_or(serde_json::Value::Null)))
155 }
156 }
157 None => None,
158 }
159 }
160
161 #[allow(dead_code)]
163 pub fn is_done(&self) -> bool {
164 !self.last_was_cont
165 }
166
167 #[allow(dead_code)]
169 pub fn req_id(&self) -> &str {
170 &self.req_id
171 }
172}
173
/// Mutable engine state, guarded by the mutex in `WshRpcEngine`.
struct EngineInner {
    /// Registered handlers, keyed by command name.
    handlers: HashMap<String, Handler>,
    /// Senders for outgoing requests still awaiting responses, keyed by request id.
    pending_responses: HashMap<String, mpsc::Sender<RpcMessage>>,
    /// Response handlers for incoming requests currently being processed, keyed by request id.
    active_handlers: HashMap<String, Arc<RpcResponseHandler>>,
    /// Token attached to outgoing command messages.
    #[allow(dead_code)]
    auth_token: String,
    /// Context passed to handlers; a default context is used while this is `None`.
    rpc_context: Option<RpcContext>,
}
184
/// RPC engine that dispatches incoming messages to registered handlers and matches
/// incoming responses to outgoing requests.
pub struct WshRpcEngine {
    /// Handler registry and request bookkeeping.
    inner: Mutex<EngineInner>,
    /// Sink for every message the engine emits; the receiving end is returned by `new`.
    output_tx: mpsc::UnboundedSender<RpcMessage>,
}
193
194impl WshRpcEngine {
195 pub fn new() -> (Arc<Self>, mpsc::UnboundedReceiver<RpcMessage>) {
198 let (output_tx, output_rx) = mpsc::unbounded_channel();
199 let engine = Arc::new(Self {
200 inner: Mutex::new(EngineInner {
201 handlers: HashMap::new(),
202 pending_responses: HashMap::new(),
203 active_handlers: HashMap::new(),
204 auth_token: String::new(),
205 rpc_context: None,
206 }),
207 output_tx,
208 });
209 (engine, output_rx)
210 }
211
212 pub fn register_handler(&self, command: &str, handler: CommandHandler) {
214 let mut inner = self.inner.lock().unwrap();
215 inner
216 .handlers
217 .insert(command.to_string(), Handler::Call(handler));
218 }
219
220 #[allow(dead_code)]
222 pub fn register_stream_handler(&self, command: &str, handler: StreamHandler) {
223 let mut inner = self.inner.lock().unwrap();
224 inner
225 .handlers
226 .insert(command.to_string(), Handler::Stream(handler));
227 }
228
229 #[allow(dead_code)]
231 pub fn set_auth_token(&self, token: &str) {
232 let mut inner = self.inner.lock().unwrap();
233 inner.auth_token = token.to_string();
234 }
235
236 #[allow(dead_code)]
238 pub fn get_auth_token(&self) -> String {
239 let inner = self.inner.lock().unwrap();
240 inner.auth_token.clone()
241 }
242
243 #[allow(dead_code)]
245 pub fn set_rpc_context(&self, ctx: RpcContext) {
246 let mut inner = self.inner.lock().unwrap();
247 inner.rpc_context = Some(ctx);
248 }
249
250 pub fn handle_message(self: &Arc<Self>, msg: RpcMessage) {
252 if msg.cancel {
254 if !msg.reqid.is_empty() {
255 self.handle_cancel_request(&msg.reqid);
256 }
257 return;
258 }
259
260 if msg.command == COMMAND_EVENT_RECV {
262 return;
264 }
265
266 if !msg.command.is_empty() {
268 let engine = self.clone();
269 tokio::spawn(async move {
270 engine.handle_request(msg).await;
271 });
272 return;
273 }
274
275 if !msg.resid.is_empty() {
277 self.handle_response(msg);
278 }
279 }
280
281 #[allow(dead_code)]
283 pub async fn send_command(
284 self: &Arc<Self>,
285 command: &str,
286 data: serde_json::Value,
287 opts: &RpcOpts,
288 ) -> Result<serde_json::Value, String> {
289 let mut handler = self.send_request(command, data, opts)?;
290 match handler.next_response().await {
291 Some(result) => result,
292 None => Err("no response received".to_string()),
293 }
294 }
295
296 #[allow(dead_code)]
298 pub fn send_request(
299 self: &Arc<Self>,
300 command: &str,
301 data: serde_json::Value,
302 opts: &RpcOpts,
303 ) -> Result<RpcRequestHandler, String> {
304 let req_id = Uuid::new_v4().to_string();
305 let (resp_tx, resp_rx) = mpsc::channel(RESP_CH_SIZE);
306
307 {
308 let mut inner = self.inner.lock().unwrap();
309 inner
310 .pending_responses
311 .insert(req_id.clone(), resp_tx);
312 }
313
314 let timeout = if opts.timeout > 0 {
315 opts.timeout
316 } else {
317 DEFAULT_TIMEOUT_MS
318 };
319 let route = if opts.route.is_empty() {
320 String::new()
321 } else {
322 opts.route.clone()
323 };
324
325 let msg = RpcMessage {
326 command: command.to_string(),
327 reqid: req_id.clone(),
328 timeout,
329 route,
330 data: Some(data),
331 authtoken: self.get_auth_token(),
332 ..Default::default()
333 };
334 self.send_output(msg);
335
336 Ok(RpcRequestHandler {
337 req_id,
338 resp_rx,
339 last_was_cont: true, })
341 }
342
343 #[allow(dead_code)]
345 pub fn send_command_no_response(
346 &self,
347 command: &str,
348 data: serde_json::Value,
349 route: &str,
350 ) {
351 let msg = RpcMessage {
352 command: command.to_string(),
353 data: Some(data),
354 route: route.to_string(),
355 authtoken: self.get_auth_token(),
356 ..Default::default()
357 };
358 self.send_output(msg);
359 }
360
361 fn send_output(&self, msg: RpcMessage) {
364 let _ = self.output_tx.send(msg);
365 }
366
367 async fn handle_request(self: Arc<Self>, msg: RpcMessage) {
368 let request_start = std::time::Instant::now();
369 let timeout_ms = if msg.timeout > 0 {
370 msg.timeout
371 } else {
372 DEFAULT_TIMEOUT_MS
373 };
374
375 let handler = Arc::new(RpcResponseHandler {
376 engine: self.clone(),
377 req_id: msg.reqid.clone(),
378 source: msg.source.clone(),
379 canceled: AtomicBool::new(false),
380 done: AtomicBool::new(false),
381 });
382
383 if !msg.reqid.is_empty() {
385 let mut inner = self.inner.lock().unwrap();
386 inner
387 .active_handlers
388 .insert(msg.reqid.clone(), handler.clone());
389 }
390
391 let rpc_context = {
392 let inner = self.inner.lock().unwrap();
393 inner.rpc_context.clone().unwrap_or_default()
394 };
395
396 let data = msg.data.unwrap_or(serde_json::Value::Null);
397 let command = msg.command.clone();
398
399 let has_call;
401 let has_stream;
402 {
403 let inner = self.inner.lock().unwrap();
404 match inner.handlers.get(&command) {
405 Some(Handler::Call(_)) => {
406 has_call = true;
407 has_stream = false;
408 }
409 Some(Handler::Stream(_)) => {
410 has_call = false;
411 has_stream = true;
412 }
413 None => {
414 has_call = false;
415 has_stream = false;
416 }
417 }
418 }
419
420 let dispatch_elapsed = request_start.elapsed();
421
422 if !has_call && !has_stream {
423 handler.send_error(&format!("unknown command: {}", command));
424 self.cleanup_handler(&msg.reqid);
425 return;
426 }
427
428 let timeout_dur = std::time::Duration::from_millis(timeout_ms as u64);
429
430 if has_call {
431 let handler_start = std::time::Instant::now();
434 let fut = {
435 let inner = self.inner.lock().unwrap();
436 match inner.handlers.get(&command) {
437 Some(Handler::Call(h)) => h(data.clone(), rpc_context.clone()),
438 _ => Box::pin(async { Err("handler disappeared".to_string()) }),
439 }
440 };
441 let result = tokio::time::timeout(timeout_dur, fut).await;
442 let handler_elapsed = handler_start.elapsed();
443 let total_elapsed = request_start.elapsed();
444
445 tracing::info!(
446 "[rpc-perf] command={} dispatch={:.2}ms handler={:.2}ms total={:.2}ms",
447 command,
448 dispatch_elapsed.as_secs_f64() * 1000.0,
449 handler_elapsed.as_secs_f64() * 1000.0,
450 total_elapsed.as_secs_f64() * 1000.0,
451 );
452
453 match result {
454 Ok(Ok(resp_data)) => handler.send_response(resp_data, true),
455 Ok(Err(err)) => handler.send_error(&err),
456 Err(_) => handler.send_error(&format!("EC-TIME: timeout ({}ms)", timeout_ms)),
457 }
458 } else {
459 let fut = {
461 let inner = self.inner.lock().unwrap();
462 match inner.handlers.get(&command) {
463 Some(Handler::Stream(h)) => h(data.clone(), rpc_context.clone()),
464 _ => Box::pin(async { Err("handler disappeared".to_string()) }),
465 }
466 };
467 let stream_result = tokio::time::timeout(timeout_dur, fut).await;
468
469 match stream_result {
470 Ok(Ok(mut rx)) => {
471 loop {
473 match tokio::time::timeout(timeout_dur, rx.recv()).await {
474 Ok(Some(Ok(resp_data))) => {
475 handler.send_response(resp_data, false);
476 }
477 Ok(Some(Err(err))) => {
478 handler.send_error(&err);
479 break;
480 }
481 Ok(None) => {
482 handler.finalize();
484 break;
485 }
486 Err(_) => {
487 handler.send_error(&format!(
488 "EC-TIME: stream timeout ({}ms)",
489 timeout_ms
490 ));
491 break;
492 }
493 }
494 }
495 }
496 Ok(Err(err)) => handler.send_error(&err),
497 Err(_) => {
498 handler.send_error(&format!("EC-TIME: timeout ({}ms)", timeout_ms))
499 }
500 }
501 }
502
503 self.cleanup_handler(&msg.reqid);
504 }
505
506 fn handle_response(&self, msg: RpcMessage) {
507 let inner = self.inner.lock().unwrap();
508 if let Some(tx) = inner.pending_responses.get(&msg.resid) {
509 let is_done = !msg.cont;
510 let _ = tx.try_send(msg.clone());
511 if is_done {
512 drop(inner);
513 let mut inner = self.inner.lock().unwrap();
514 inner.pending_responses.remove(&msg.resid);
515 }
516 }
517 }
518
519 fn handle_cancel_request(&self, req_id: &str) {
520 let inner = self.inner.lock().unwrap();
521 if let Some(handler) = inner.active_handlers.get(req_id) {
522 handler.cancel();
523 }
524 }
525
526 fn cleanup_handler(&self, req_id: &str) {
527 if req_id.is_empty() {
528 return;
529 }
530 let mut inner = self.inner.lock().unwrap();
531 inner.active_handlers.remove(req_id);
532 }
533}
534
535#[cfg(test)]
540mod tests {
541 use super::*;
542
543 #[tokio::test]
544 async fn test_register_and_call_handler() {
545 let (engine, mut output_rx) = WshRpcEngine::new();
546
547 engine.register_handler(
548 "echo",
549 Box::new(|data, _ctx| {
550 Box::pin(async move { Ok(Some(data)) })
551 }),
552 );
553
554 let msg = RpcMessage {
555 command: "echo".to_string(),
556 reqid: "req-1".to_string(),
557 data: Some(serde_json::json!({"hello": "world"})),
558 ..Default::default()
559 };
560 engine.handle_message(msg);
561
562 let resp = tokio::time::timeout(
564 std::time::Duration::from_secs(1),
565 output_rx.recv(),
566 )
567 .await
568 .unwrap()
569 .unwrap();
570
571 assert_eq!(resp.resid, "req-1");
572 assert!(!resp.cont);
573 assert_eq!(resp.data, Some(serde_json::json!({"hello": "world"})));
574 }
575
576 #[tokio::test]
577 async fn test_unknown_command_returns_error() {
578 let (engine, mut output_rx) = WshRpcEngine::new();
579
580 let msg = RpcMessage {
581 command: "nonexistent".to_string(),
582 reqid: "req-2".to_string(),
583 ..Default::default()
584 };
585 engine.handle_message(msg);
586
587 let resp = tokio::time::timeout(
588 std::time::Duration::from_secs(1),
589 output_rx.recv(),
590 )
591 .await
592 .unwrap()
593 .unwrap();
594
595 assert_eq!(resp.resid, "req-2");
596 assert!(resp.error.contains("unknown command"));
597 }
598
599 #[tokio::test]
600 async fn test_handler_error_returns_error_response() {
601 let (engine, mut output_rx) = WshRpcEngine::new();
602
603 engine.register_handler(
604 "failme",
605 Box::new(|_data, _ctx| {
606 Box::pin(async move { Err("something went wrong".to_string()) })
607 }),
608 );
609
610 let msg = RpcMessage {
611 command: "failme".to_string(),
612 reqid: "req-3".to_string(),
613 ..Default::default()
614 };
615 engine.handle_message(msg);
616
617 let resp = tokio::time::timeout(
618 std::time::Duration::from_secs(1),
619 output_rx.recv(),
620 )
621 .await
622 .unwrap()
623 .unwrap();
624
625 assert_eq!(resp.error, "something went wrong");
626 }
627
628 #[tokio::test]
629 async fn test_send_command_roundtrip() {
630 let (engine, mut output_rx) = WshRpcEngine::new();
631
632 let engine_clone = engine.clone();
634 tokio::spawn(async move {
635 if let Some(msg) = output_rx.recv().await {
636 let resp = RpcMessage {
638 resid: msg.reqid.clone(),
639 data: msg.data.clone(),
640 ..Default::default()
641 };
642 engine_clone.handle_message(resp);
643 }
644 });
645
646 let opts = RpcOpts {
647 timeout: 1000,
648 ..Default::default()
649 };
650 let result = engine
651 .send_command("test", serde_json::json!(42), &opts)
652 .await;
653
654 assert!(result.is_ok());
655 assert_eq!(result.unwrap(), serde_json::json!(42));
656 }
657
658 #[tokio::test]
659 async fn test_stream_handler() {
660 let (engine, mut output_rx) = WshRpcEngine::new();
661
662 engine.register_stream_handler(
663 "counter",
664 Box::new(|_data, _ctx| {
665 Box::pin(async move {
666 let (tx, rx) = mpsc::channel(8);
667 tokio::spawn(async move {
668 for i in 0..3 {
669 let _ = tx.send(Ok(Some(serde_json::json!(i)))).await;
670 }
671 });
673 Ok(rx)
674 })
675 }),
676 );
677
678 let msg = RpcMessage {
679 command: "counter".to_string(),
680 reqid: "req-stream".to_string(),
681 ..Default::default()
682 };
683 engine.handle_message(msg);
684
685 let mut responses = Vec::new();
687 for _ in 0..4 {
688 match tokio::time::timeout(
690 std::time::Duration::from_secs(2),
691 output_rx.recv(),
692 )
693 .await
694 {
695 Ok(Some(resp)) => responses.push(resp),
696 _ => break,
697 }
698 }
699
700 assert!(responses.len() >= 3);
702 for resp in &responses[..3] {
704 assert!(resp.cont);
705 }
706 if responses.len() == 4 {
708 assert!(!responses[3].cont);
709 }
710 }
711
712 #[tokio::test]
713 async fn test_cancel_request() {
714 let (engine, mut output_rx) = WshRpcEngine::new();
715
716 let (started_tx, started_rx) = tokio::sync::oneshot::channel::<()>();
717 engine.register_handler(
718 "slow",
719 Box::new(move |_data, _ctx| {
720 Box::pin(async move {
721 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
724 Ok(Some(serde_json::json!("done")))
725 })
726 }),
727 );
728
729 let msg = RpcMessage {
731 command: "slow".to_string(),
732 reqid: "req-cancel".to_string(),
733 timeout: 10000,
734 ..Default::default()
735 };
736 engine.handle_message(msg);
737
738 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
740 let cancel_msg = RpcMessage {
741 cancel: true,
742 reqid: "req-cancel".to_string(),
743 ..Default::default()
744 };
745 engine.handle_message(cancel_msg);
746
747 let resp = tokio::time::timeout(
750 std::time::Duration::from_secs(12),
751 output_rx.recv(),
752 )
753 .await;
754 assert!(resp.is_ok());
755 drop(started_tx);
757 drop(started_rx);
758 }
759
760 #[tokio::test]
761 async fn test_send_command_no_response() {
762 let (engine, mut output_rx) = WshRpcEngine::new();
763
764 engine.send_command_no_response("notify", serde_json::json!({"msg": "hi"}), "");
765
766 let msg = tokio::time::timeout(
767 std::time::Duration::from_millis(100),
768 output_rx.recv(),
769 )
770 .await
771 .unwrap()
772 .unwrap();
773
774 assert_eq!(msg.command, "notify");
775 assert!(msg.reqid.is_empty());
776 }
777
778 #[tokio::test]
779 async fn test_auth_token() {
780 let (engine, _output_rx) = WshRpcEngine::new();
781 assert!(engine.get_auth_token().is_empty());
782
783 engine.set_auth_token("my-secret-token");
784 assert_eq!(engine.get_auth_token(), "my-secret-token");
785 }
786
787 #[tokio::test]
788 async fn test_rpc_context() {
789 let (engine, _output_rx) = WshRpcEngine::new();
790
791 let ctx = RpcContext {
792 client_type: "connserver".to_string(),
793 blockid: "blk-1".to_string(),
794 ..Default::default()
795 };
796 engine.set_rpc_context(ctx);
797
798 engine.register_handler(
800 "checkctx",
801 Box::new(|_data, ctx| {
802 Box::pin(async move {
803 Ok(Some(serde_json::json!({
804 "ctype": ctx.client_type,
805 "blockid": ctx.blockid,
806 })))
807 })
808 }),
809 );
810
811 let msg = RpcMessage {
812 command: "checkctx".to_string(),
813 reqid: "req-ctx".to_string(),
814 ..Default::default()
815 };
816 engine.handle_message(msg);
817
818 }
821}